'''
Function:
	吃豆豆小游戏
'''
import os
import sys
import pygame
import Levels

'''
手势识别
'''
import cv2
import imutils
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image


# ---------------------------------------------------------------------------
# Required parameters: RGB colours and resource file paths.
# ---------------------------------------------------------------------------
BLACK = (0, 0, 0)
WHITE = (255, 255, 255)
BLUE = (0, 0, 255)
GREEN = (0, 255, 0)
RED = (255, 0, 0)
YELLOW = (255, 255, 0)
PURPLE = (255, 0, 255)
SKYBLUE = (0, 191, 255)


def _resource(relative_path):
    """Return *relative_path* joined onto the current working directory."""
    return os.path.join(os.getcwd(), relative_path)


# Resource paths are resolved against the working directory at import time.
BGMPATH = _resource('resources/sounds/bg.mp3')
ICONPATH = _resource('resources/images/icon.png')
FONTPATH = _resource('resources/font/SimHei.ttf')
HEROPATH = _resource('resources/images/pacman.png')
BlinkyPATH = _resource('resources/images/Blinky.png')
ClydePATH = _resource('resources/images/Clyde.png')
InkyPATH = _resource('resources/images/Inky.png')
PinkyPATH = _resource('resources/images/Pinky.png')


# Initialisation.
def initialize():
    """Initialise pygame, set the window icon and caption, and return the screen.

    Returns:
        The 606x606 pygame display surface created by ``set_mode``.
    """
    pygame.init()
    window_icon = pygame.image.load(ICONPATH)
    pygame.display.set_icon(window_icon)
    game_screen = pygame.display.set_mode([606, 606])
    pygame.display.set_caption('手势识别吃豆人')
    return game_screen

# Show the end-of-level text.
def showText(screen, font, is_clearance, flag=False):
    """Overlay the end-of-level message and wait for the player's choice.

    Args:
        screen: pygame display surface to draw on.
        font: pygame font used for the three message lines.
        is_clearance: True if the level was cleared, False on game over.
        flag: True if the cleared level was the last one. When a level was
            cleared and ``flag`` is False, pressing Enter returns to the
            caller; in every other case Enter restarts the game by calling
            ``main(initialize())``.

    Pressing Esc or closing the window shuts pygame down and exits.
    """
    clock = pygame.time.Clock()
    msg = r'游戏结束!' if not is_clearance else r'恭喜，您赢了'
    positions = [[235, 233], [200, 303], [200, 333]] if not is_clearance else [[145, 233], [65, 303], [170, 333]]
    surface = pygame.Surface((400, 200))
    surface.set_alpha(10)  # almost fully transparent overlay
    surface.fill((128, 128, 128))
    screen.blit(surface, (100, 200))
    texts = [font.render(msg, True, WHITE),
             font.render(r'按回车再来一遍', True, WHITE),
             font.render(r'按ESC退出', True, WHITE)]
    while True:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                # Shut pygame down *before* exiting; after sys.exit() nothing runs.
                pygame.quit()
                sys.exit()
            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_RETURN:
                    if is_clearance and not flag:
                        return
                    # NOTE(review): restarting re-enters main() from inside
                    # main()'s call to showText, so every restart nests another
                    # set of frames; consider a restart loop in main() instead.
                    main(initialize())
                elif event.key == pygame.K_ESCAPE:
                    pygame.quit()
                    sys.exit()
        for text, position in zip(texts, positions):
            screen.blit(text, position)
        pygame.display.flip()
        clock.tick(10)

def _predict_gesture(thresholded):
    """Classify a thresholded hand mask and return the predicted label.

    Uses the module-level ``transform``, ``model`` and ``classes``.
    """
    input_batch = transform(thresholded).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        output = model(input_batch)
    _, predicted_idx = torch.max(output, 1)
    return classes[predicted_idx.item()]


def _next_track_index(ghost):
    """Return the index of the track segment after the ghost's current one.

    After the last segment, Clyde resumes from segment 2 and every other
    ghost starts over from segment 0.
    """
    if ghost.tracks_loc[0] < len(ghost.tracks) - 1:
        return ghost.tracks_loc[0] + 1
    if ghost.role_name == 'Clyde':
        return 2
    return 0


def _steer_ghost(ghost):
    """Advance a ghost one step along its scripted track and set its speed.

    Each ``ghost.tracks`` entry supplies a speed in items ``[0:2]`` (passed to
    ``changeSpeed``) and a step count in item ``[2]``. ``ghost.tracks_loc``
    holds ``[segment index, steps taken in that segment]``.
    """
    # Choose the segment for this tick: keep going, or move to the next one.
    if ghost.tracks_loc[1] < ghost.tracks[ghost.tracks_loc[0]][2]:
        ghost.changeSpeed(ghost.tracks[ghost.tracks_loc[0]][0: 2])
        ghost.tracks_loc[1] += 1
    else:
        ghost.tracks_loc[0] = _next_track_index(ghost)
        ghost.changeSpeed(ghost.tracks[ghost.tracks_loc[0]][0: 2])
        ghost.tracks_loc[1] = 0
    # Look ahead: if this segment is exhausted, already face the next one.
    if ghost.tracks_loc[1] < ghost.tracks[ghost.tracks_loc[0]][2]:
        ghost.changeSpeed(ghost.tracks[ghost.tracks_loc[0]][0: 2])
    else:
        ghost.changeSpeed(ghost.tracks[_next_track_index(ghost)][0: 2])


# Start a game level.
def startLevelGame(level, screen, font):
    """Run one level, steering the hero with hand gestures from the webcam.

    The first 30 camera frames seed a running-average background model
    (``run_avg``). After that, ``segment`` extracts the hand mask from the
    region of interest, which is shown in OpenCV windows. The first ``s``
    keypress read by ``cv2.waitKey`` (once a hand mask exists) classifies the
    gesture once and starts the game. From then on, every frame is classified
    and the predicted direction is applied to the hero. A ``q`` keypress quits.

    Args:
        level: level object providing ``setupWalls``, ``setupGate``,
            ``setupPlayers`` and ``setupFood``.
        screen: pygame display surface to draw on.
        font: pygame font for the score text.

    Returns:
        True when all food has been eaten, False when a ghost touches the hero.

    Raises:
        RuntimeError: if camera 0 cannot be opened.
    """
    clock = pygame.time.Clock()
    score = 0
    wall_sprites = level.setupWalls(SKYBLUE)
    gate_sprites = level.setupGate(WHITE)
    hero_sprites, ghost_sprites = level.setupPlayers(HEROPATH, [BlinkyPATH, ClydePATH, InkyPATH, PinkyPATH])
    food_sprites = level.setupFood(YELLOW, WHITE)

    # Gesture-recognition settings.
    a_weight = 0.5  # weight passed to run_avg() for the background model
    # ROI in the resized, mirrored frame; note that `right` < `left` numerically.
    top, right, bottom, left = 90, 380, 285, 590
    # Speed per recognised gesture; any other label (e.g. 'pause') changes nothing.
    directions = {
        'left': (-1, 0),
        'right': (1, 0),
        'up': (0, -1),
        'down': (0, 1),
    }
    num_frames = 0
    thresholded = None
    started = False  # becomes True after the first accepted 's' keypress

    camera = cv2.VideoCapture(0)
    if not camera.isOpened():
        camera.release()
        raise RuntimeError('Unable to open camera 0')
    try:
        while True:
            # Drain pygame events so the window stays responsive and closable,
            # and stale key presses do not leak into showText() afterwards.
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    sys.exit()

            grabbed, frame = camera.read()
            if not grabbed:
                continue

            frame = imutils.resize(frame, width=700)
            frame = cv2.flip(frame, 1)  # horizontal flip (mirror image)
            clone = frame.copy()
            roi = frame[top:bottom, right:left]
            gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
            gray = cv2.GaussianBlur(gray, (7, 7), 0)

            if num_frames < 30:
                # Calibrate the background model on the first 30 frames.
                run_avg(gray, a_weight)
            else:
                hand = segment(gray)
                if hand is not None:
                    thresholded, segmented = hand
                    # The contour is in ROI coordinates; shift it into the frame.
                    cv2.drawContours(clone, [segmented + (right, top)], -1, (0, 0, 255))

            cv2.rectangle(clone, (left, top), (right, bottom), (0, 255, 0), 2)
            num_frames += 1

            cv2.imshow('Video Feed', clone)
            if thresholded is not None:
                cv2.imshow('Thesholded', thresholded)
            keypress = cv2.waitKey(1) & 0xFF

            if keypress == ord('q'):
                pygame.quit()
                sys.exit(-1)

            if not started:
                # Wait for 's'. Ignore it until a hand mask exists, because the
                # classifier cannot run on None.
                if keypress == ord('s') and thresholded is not None:
                    print("预测结果:", _predict_gesture(thresholded))
                    started = True
                continue

            label = _predict_gesture(thresholded)
            print("预测结果:", label)
            direction = directions.get(label)
            if direction is not None:
                for hero in hero_sprites:
                    hero.changeSpeed(list(direction))
                    hero.is_move = True

            screen.fill(BLACK)
            for hero in hero_sprites:
                hero.update(wall_sprites, gate_sprites)
            hero_sprites.draw(screen)
            for hero in hero_sprites:
                # dokill=True removes eaten food from food_sprites.
                score += len(pygame.sprite.spritecollide(hero, food_sprites, True))
            wall_sprites.draw(screen)
            gate_sprites.draw(screen)
            food_sprites.draw(screen)
            for ghost in ghost_sprites:
                _steer_ghost(ghost)
                ghost.update(wall_sprites, None)
            ghost_sprites.draw(screen)
            score_text = font.render("Score: %s" % score, True, RED)
            screen.blit(score_text, [10, 10])

            if len(food_sprites) == 0:
                return True  # all food eaten: level cleared
            if pygame.sprite.groupcollide(hero_sprites, ghost_sprites, False, False):
                return False  # caught by a ghost
            pygame.display.flip()
            clock.tick(5)
    finally:
        # Release the webcam and close the OpenCV windows so that a restarted
        # game can reopen the device.
        camera.release()
        cv2.destroyAllWindows()

# Main function.
def main(screen):
    """Play the game on *screen*.

    Starts the background music looping, loads size-18 and size-24 fonts,
    runs level 1 with ``startLevelGame`` and shows the end-of-level message
    with ``showText``. ``showText`` is told whether level 1 is also the last
    level.
    """
    pygame.mixer.init()
    music = pygame.mixer.music
    music.load(BGMPATH)
    music.play(-1, 0.0)  # loop indefinitely, starting at 0 s
    pygame.font.init()
    font_small, font_big = [pygame.font.Font(FONTPATH, size) for size in (18, 24)]
    for num_level in range(1, Levels.NUMLEVELS + 1):
        if num_level != 1:
            # Only level 1 is wired up here; later indices are skipped.
            continue
        is_clearance = startLevelGame(Levels.Level1(), screen, font_small)
        showText(screen, font_big, is_clearance, num_level == Levels.NUMLEVELS)

# Running-average background model shared by run_avg() and segment();
# stays None until run_avg() receives its first frame.
bg = None

# Classifier labels, listed in the order of the model's output indices.
classes = ['down', 'left', 'pause', 'right', 'up']

# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# model files. Newer PyTorch releases default to weights_only=True, which may
# reject a fully pickled module; pass weights_only=False if loading fails.
model = torch.load('model.pt', map_location='cpu')
# Put the network into inference mode. Otherwise any dropout layers stay
# active and batch-norm layers update their running statistics on every
# single-frame prediction.
model.eval()

transform = transforms.Compose([
    transforms.ToTensor(),  # convert the image to a tensor
    transforms.Normalize((0.5,), (0.5,))  # normalise the tensor
])


def run_avg(image, aWeight):
    """Update the module-level running-average background model ``bg``.

    The first call seeds ``bg`` with a float copy of *image*. Later calls
    blend *image* into ``bg`` in place via ``cv2.accumulateWeighted`` with
    weight *aWeight*.
    """
    global bg
    if bg is not None:
        cv2.accumulateWeighted(image, bg, aWeight)
    else:
        bg = image.copy().astype('float')


def segment(image, threshold=25):
    """Segment the hand by differencing *image* against the background model.

    Reads the module-level ``bg`` (filled by ``run_avg``), so ``run_avg`` must
    have been called at least once first.

    Args:
        image: grayscale ROI frame, matching the size of ``bg``.
        threshold: minimum absolute difference (0-255) for a pixel to be
            treated as foreground.

    Returns:
        ``(thresholded, segmented)``, i.e. the binary foreground mask and the
        largest external contour, or ``None`` when no contour is found.
    """
    # bg is a float accumulator; convert it to 8-bit to diff against the frame.
    diff = cv2.absdiff(bg.astype('uint8'), image)
    _, thresholded = cv2.threshold(diff, threshold, 255, cv2.THRESH_BINARY)

    # findContours returns (image, contours, hierarchy) on OpenCV 3.x but
    # (contours, hierarchy) on 2.x/4.x; index -2 is the contours in every case.
    cnts = cv2.findContours(thresholded.copy(),
                            cv2.RETR_EXTERNAL,
                            cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(cnts) == 0:
        return None
    segmented = max(cnts, key=cv2.contourArea)
    return thresholded, segmented

# Script entry point: create the window, then run the game on it.
if __name__ == '__main__':
    game_screen = initialize()
    main(game_screen)
